{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WebIndexedDataset + Distributed PyTorch Training\n",
    "\n",
    "This notebook illustrates how to use the Web Indexed Dataset (`wids`) library for distributed PyTorch training using `DistributedDataParallel`.\n",
    "\n",
    "Using `wids` results in training code that is almost identical to plain PyTorch, with the only changes being the use of `ShardListDataset` for the dataset construction, and the use of the `DistributedChunkedSampler` for generating random samples from the dataset.\n",
    "\n",
    "`ShardListDataset` requires some local storage. By default, that local storage just grows as shards are downloaded, but if you have limited space, you can run `create_cleanup_background_process` to clean up the cache; shards will be re-downloaded as necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from typing import (\n",
    "    List,\n",
    "    Tuple,\n",
    "    Dict,\n",
    "    Optional,\n",
    "    Any,\n",
    "    Union,\n",
    "    Callable,\n",
    "    Iterable,\n",
    "    Iterator,\n",
    "    NamedTuple,\n",
    "    Set,\n",
    "    Sequence,\n",
    ")\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.distributed as dist\n",
    "from torch.nn.parallel import DistributedDataParallel\n",
    "from torchvision.models import resnet50\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import ray\n",
    "import wids\n",
    "import dataclasses\n",
    "import time\n",
    "from collections import deque\n",
    "from pprint import pprint\n",
    "\n",
    "\n",
    "def enumerate_report(seq, delta, growth=1.0):\n",
    "    last = 0\n",
    "    count = 0\n",
    "    for count, item in enumerate(seq):\n",
    "        now = time.time()\n",
    "        if now - last > delta:\n",
    "            last = now\n",
    "            yield count, item, True\n",
    "        else:\n",
    "            yield count, item, False\n",
    "        delta *= growth"
   ]
  },
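  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A tiny illustration of enumerate_report (not part of the training code):\n",
    "# iterate over a dummy sequence and show which iterations get flagged for a\n",
    "# progress report, using a 0.1 second reporting interval.\n",
    "for i, item, verbose in enumerate_report(range(5), 0.1):\n",
    "    print(i, item, verbose)\n",
    "    time.sleep(0.05)"
   ]
  },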
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loading for Distributed Training\n",
    "\n",
    "The datasets we use for training are stored in the cloud.\n",
    "We use `fake-imagenet`, which is 1/10th the size of Imagenet\n",
    "and artificially generated, but it has the same number of\n",
    "shards and trains quickly.\n",
    "\n",
    "Note that unlike the `webdataset` library, `wids` always needs\n",
    "a local cache directory (it will use `/tmp` if you don't give it\n",
    "anything explicitly)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "epochs = 1\n",
    "max_steps = int(1e12)\n",
    "batch_size = 32\n",
    "bucket = \"https://storage.googleapis.com/webdataset/fake-imagenet/\"\n",
    "trainset_url = bucket+\"imagenet-train.json\"\n",
    "valset_url = bucket+\"imagenet-val.json\"\n",
    "cache_dir = \"./_cache\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a typical PyTorch dataset, except that we read from the cloud.\n",
    "\n",
    "\n",
    "def make_dataset_train():\n",
    "    transform_train = transforms.Compose(\n",
    "        [\n",
    "            transforms.RandomResizedCrop(224),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    def make_sample(sample):\n",
    "        image = sample[\".jpg\"]\n",
    "        label = sample[\".cls\"]\n",
    "        return transform_train(image), label\n",
    "\n",
    "    trainset = wids.ShardListDataset(trainset_url, cache_dir=\"./_cache\", keep=True)\n",
    "    trainset = trainset.add_transform(make_sample)\n",
    "\n",
    "    return trainset"
   ]
  },
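  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: ShardListDataset is an ordinary map-style dataset,\n",
    "# so it supports len() and integer indexing. Note that accessing a sample\n",
    "# downloads the corresponding shard into the local cache.\n",
    "trainset = make_dataset_train()\n",
    "print(len(trainset))\n",
    "image, label = trainset[0]\n",
    "print(image.shape, label)"
   ]
  },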
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is really the only thing that is ever so slightly special about the `wids` library:\n",
    "you should use the special `DistributedChunkedSampler` for sampling.\n",
    "\n",
    "The regular `DistributedSampler` will technically work, but because of its poor locality\n",
    "of reference, will be significantly slower."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To keep locality of reference in the dataloader, we use a special sampler\n",
    "# for distributed training, DistributedChunkedSampler.\n",
    "\n",
    "\n",
    "def make_dataloader_train():\n",
    "    dataset = make_dataset_train()\n",
    "    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)\n",
    "    dataloader = DataLoader(\n",
    "        dataset, batch_size=batch_size, sampler=sampler, num_workers=4\n",
    "    )\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "def make_dataloader(split=\"train\"):\n",
    "    \"\"\"Make a dataloader for training or validation.\"\"\"\n",
    "    if split == \"train\":\n",
    "        return make_dataloader_train()\n",
    "    elif split == \"val\":\n",
    "        return make_dataloader_val()\n",
    "    else:\n",
    "        raise ValueError(f\"unknown split {split}\")\n",
    "\n",
    "\n",
    "# Try it out.\n",
    "sample = next(iter(make_dataloader()))\n",
    "print(sample[0].shape, sample[1].shape)"
   ]
  },
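  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`make_dataloader` above also refers to a validation loader, which this notebook does not otherwise define.\n",
    "The cell below is a minimal sketch of what it could look like; it simply mirrors the training pipeline,\n",
    "using `valset_url`, a deterministic center crop, and no shuffling. It is not used by the training runs below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal validation pipeline, mirroring make_dataset_train / make_dataloader_train.\n",
    "\n",
    "\n",
    "def make_dataset_val():\n",
    "    transform_val = transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize(256),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    def make_sample(sample):\n",
    "        image = sample[\".jpg\"]\n",
    "        label = sample[\".cls\"]\n",
    "        return transform_val(image), label\n",
    "\n",
    "    valset = wids.ShardListDataset(valset_url, cache_dir=cache_dir, keep=True)\n",
    "    valset = valset.add_transform(make_sample)\n",
    "    return valset\n",
    "\n",
    "\n",
    "def make_dataloader_val():\n",
    "    dataset = make_dataset_val()\n",
    "    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=False)\n",
    "    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=4)"
   ]
  },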
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch Distributed Training Code\n",
    "\n",
    "Really, all that's needed for distributed training is the `DistributedDataParallel` wrapper around the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For convenience, we collect all the configuration parameters into\n",
    "# a dataclass.\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class Config:\n",
    "    rank: Optional[int] = None\n",
    "    epochs: int = 1\n",
    "    max_steps: int = int(1e18)\n",
    "    lr: float = 0.001\n",
    "    momentum: float = 0.9\n",
    "    world_size: int = 8\n",
    "    backend: str = \"nccl\"\n",
    "    master_addr: str = \"localhost\"\n",
    "    master_port: str = \"12355\"\n",
    "    report_s: float = 15.0\n",
    "    report_growth: float = 1.1\n",
    "\n",
    "\n",
    "Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A typical PyTorch training function.\n",
    "\n",
    "\n",
    "def train(config):\n",
    "    # Define the model, loss function, and optimizer\n",
    "    model = resnet50(pretrained=False).cuda()\n",
    "    if config.rank is not None:\n",
    "        model = DistributedDataParallel(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)\n",
    "\n",
    "    # Data loading code\n",
    "    trainloader = make_dataloader(split=\"train\")\n",
    "\n",
    "    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0\n",
    "\n",
    "    # Training loop\n",
    "    for epoch in range(config.epochs):\n",
    "        for i, data, verbose in enumerate_report(trainloader, config.report_s):\n",
    "            inputs, labels = data[0].cuda(), data[1].cuda()\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = model(inputs)\n",
    "            loss = loss_fn(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # just bookkeping and progress report\n",
    "            steps += len(labels)\n",
    "            accuracy = (\n",
    "                (outputs.argmax(1) == labels).float().mean()\n",
    "            )  # calculate accuracy\n",
    "            losses.append(loss.item())\n",
    "            accuracies.append(accuracy.item())\n",
    "            if verbose and len(losses) > 0:\n",
    "                avgloss = sum(losses) / len(losses)\n",
    "                avgaccuracy = sum(accuracies) / len(accuracies)\n",
    "                print(\n",
    "                    f\"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}\",\n",
    "                    file=sys.stderr,\n",
    "                )\n",
    "            if steps > config.max_steps:\n",
    "                print(\n",
    "                    \"finished training (max_steps)\",\n",
    "                    steps,\n",
    "                    config.max_steps,\n",
    "                    file=sys.stderr,\n",
    "                )\n",
    "                return\n",
    "\n",
    "    print(\"finished Training\", steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick smoke test.\n",
    "\n",
    "os.environ[\"GOPEN_VERBOSE\"] = \"1\"\n",
    "config = Config()\n",
    "config.epochs = 1\n",
    "config.max_steps = 1000\n",
    "train(config)\n",
    "os.environ[\"GOPEN_VERBOSE\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed Training in Ray\n",
    "\n",
    "The code above can be used with any distributed computing framwork, including `torch.distributed.launch`.\n",
    "\n",
    "Below is simply an example of how to launch the training jobs with the Ray framework. Ray is nice\n",
    "for distributed training because it makes the Python code independent of the runtime environment\n",
    "(Kubernetes, Slurm, ad-hoc networking, etc.). Meaning, the code below will work regardless of how\n",
    "you start up your Ray cluster."
   ]
  },
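  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For reference only: a sketch of how the same train() function could be driven by\n",
    "# torchrun / torch.distributed.launch instead of Ray. Those launchers set RANK,\n",
    "# LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT in the environment, so the\n",
    "# process group can be initialized directly from the environment. You would put this\n",
    "# in a standalone script (the name train_ddp.py is just a placeholder) and run, e.g.,\n",
    "# torchrun --nproc_per_node=8 train_ddp.py. This cell only defines the function; it\n",
    "# is not called anywhere in this notebook.\n",
    "\n",
    "\n",
    "def main_torchrun():\n",
    "    config = Config()\n",
    "    config.rank = int(os.environ[\"RANK\"])\n",
    "    config.world_size = int(os.environ[\"WORLD_SIZE\"])\n",
    "    # Bind this process to its local GPU before the model is created.\n",
    "    torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n",
    "    dist.init_process_group(backend=config.backend)\n",
    "    train(config)\n",
    "    dist.destroy_process_group()"
   ]
  },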
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The distributed training function to be used with Ray.\n",
    "# Since this is started via Ray remote, we set up the distributed\n",
    "# training environment here.\n",
    "\n",
    "\n",
    "@ray.remote(num_gpus=1)\n",
    "def train_in_ray(rank, config):\n",
    "    if rank is not None:\n",
    "        # Set up distributed PyTorch.\n",
    "        config.rank = rank\n",
    "        os.environ[\"MASTER_ADDR\"] = config.master_addr\n",
    "        os.environ[\"MASTER_PORT\"] = config.master_port\n",
    "        dist.init_process_group(\n",
    "            backend=config.backend, rank=rank, world_size=config.world_size\n",
    "        )\n",
    "        # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.\n",
    "    train(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not ray.is_initialized():\n",
    "    ray.init()\n",
    "print(\"#gpus available in the cluster\", ray.available_resources()[\"GPU\"])\n",
    "\n",
    "\n",
    "def distributed_training(config):\n",
    "    num_gpus = ray.available_resources()[\"GPU\"]\n",
    "    config.world_size = int(min(config.world_size, num_gpus))\n",
    "    pprint(config)\n",
    "    results = ray.get(\n",
    "        [train_in_ray.remote(i, config) for i in range(config.world_size)]\n",
    "    )\n",
    "    print(results)\n",
    "\n",
    "\n",
    "config = Config()\n",
    "config.epochs = epochs\n",
    "config.max_steps = max_steps\n",
    "config.batch_size = batch_size\n",
    "print(config)\n",
    "distributed_training(config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
